{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "collapsed_sections": [
        "AkmZkTv6WW7s"
      ],
      "machine_shape": "hm"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "##### Copyright 2023 Google LLC. SPDX-License-Identifier: Apache-2.0"
      ],
      "metadata": {
        "id": "AkmZkTv6WW7s"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Copyright 2023 Google LLC. SPDX-License-Identifier: Apache-2.0\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at\n",
        "\n",
        "https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
      ],
      "metadata": {
        "id": "I36DZu2LWZkE"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "# **Robots That Ask For Help: Uncertainty Alignment for Large Language Model Planners** Demo\n",
        "\n",
        "[KnowNo](https://robot-help.github.io) is a framework for measuring and aligning the uncertainty of LLM-based planners, such that they know when they don't know, and ask for help when needed. KnowNo builds on the theory of conformal prediction to provide statistical guarantees on task completion while minimizing human help.\n",
        "\n",
        "This colab shows the very basics of constructing the prediction set (possible actions in a scenario) in the Mobile Manipulation setting. The left side of the figure belore shows a sample scenario.\n",
        "\n",
        "<img src=\"https://robot-help.github.io/img/robot-help-teaser.png\" height=\"280px\">\n",
        "\n",
        "Note:\n",
        "* Instead of setting up the scenario distribution here, we will load a dataset sampled from a pre-defined scenario distribution involving the mobile robot, the same used in the experiments. We will also use calibration results already computed with the distribution.\n",
        "* We use [GPT-3.5](https://arxiv.org/abs/2005.14165) (text-davinci-003) as the language model here.\n",
        "* We focus on the planning part; we do not consider object detection or low-level action execution here.\n",
        "\n",
        "Disclaimer: We fine the GPT3.5 model significantly underperforms [PaLM2-L](https://ai.google/discover/palm2/) model used in our experiments, largely due to its bias towards option C and D over option A and B in multiple choice question answering. We also find such bias dependent on the context, so adjusting bias for certain options in the API call does not help significantly."
      ],
      "metadata": {
        "id": "cHsncwrPOxZt"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "openai_api_key = \"your-api-key\""
      ],
      "metadata": {
        "id": "Eycru54hVK9d"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## **Setup**"
      ],
      "metadata": {
        "id": "x3wvRmWYVPLA"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "#@markdown A few imports\n",
        "!pip install openai tqdm\n",
        "\n",
        "import openai\n",
        "import signal\n",
        "import tqdm.notebook as tqdm\n",
        "import random\n",
        "import numpy as np\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# Set OpenAI API key.\n",
        "openai.api_key = openai_api_key"
      ],
      "metadata": {
        "id": "DpR4dgevMMsa",
        "cellView": "form"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@markdown LLM API call\n",
        "class timeout:\n",
        "    def __init__(self, seconds=1, error_message='Timeout'):\n",
        "        self.seconds = seconds\n",
        "        self.error_message = error_message\n",
        "\n",
        "    def handle_timeout(self, signum, frame):\n",
        "        raise TimeoutError(self.error_message)\n",
        "\n",
        "    def __enter__(self):\n",
        "        signal.signal(signal.SIGALRM, self.handle_timeout)\n",
        "        signal.alarm(self.seconds)\n",
        "\n",
        "    def __exit__(self, type, value, traceback):\n",
        "        signal.alarm(0)\n",
        "\n",
        "# OpenAI only supports up to five tokens (logprobs argument) for getting the likelihood.\n",
        "# Thus we use the logit_bias argument to force LLM only consdering the five option\n",
        "# tokens: A, B, C, D, E\n",
        "def lm(prompt,\n",
        "       max_tokens=256,\n",
        "       temperature=0,\n",
        "       logprobs=None,\n",
        "       stop_seq=None,\n",
        "       logit_bias={\n",
        "          317: 100.0,   #  A (with space at front)\n",
        "          347: 100.0,   #  B (with space at front)\n",
        "          327: 100.0,   #  C (with space at front)\n",
        "          360: 100.0,   #  D (with space at front)\n",
        "          412: 100.0,   #  E (with space at front)\n",
        "      },\n",
        "       timeout_seconds=20):\n",
        "  max_attempts = 5\n",
        "  for _ in range(max_attempts):\n",
        "      try:\n",
        "          with timeout(seconds=timeout_seconds):\n",
        "              response = openai.Completion.create(\n",
        "                  model='text-davinci-003',\n",
        "                  prompt=prompt,\n",
        "                  max_tokens=max_tokens,\n",
        "                  temperature=temperature,\n",
        "                  logprobs=logprobs,\n",
        "                  logit_bias=logit_bias,\n",
        "                  stop=list(stop_seq) if stop_seq is not None else None,\n",
        "              )\n",
        "          break\n",
        "      except:\n",
        "          print('Timeout, retrying...')\n",
        "          pass\n",
        "  return response, response[\"choices\"][0][\"text\"].strip()"
      ],
      "metadata": {
        "cellView": "form",
        "id": "SZF1j4s_VRdq"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## **Specify the instruction**\n",
        "Consider a setting where there can be a counter with three objects on top it (figure below). There are also a top drawer and a bottom drawer under the counter. There is a set of landfill, recycling, and compost bins next to the counter (not shown).\n",
        "\n",
        "<img src=\"https://robot-help.github.io/img/sample-mobile-manipulation.png\" height=\"200px\">\n",
        "\n",
        "The possible task instruction, for example, can be \"pick up the apple\", \"put the apple in the drawer\" (unclear about the choice of drawer), and \"dispose of the apple\".\n",
        "\n",
        "Besides the apple, orange, and Sprite shown in the image, we have also calibrated the LLM to perform tasks with these objects: bottled water, bottled tea, orange soda, RedBull, Coke, Pepsi, rice chips, jalapeno chips, kettle chips, multigrain chips, energy bar, dirty sponge with food residue, clean sponge, metal bowl, plastic bowl.\n",
        "\n",
        "Now you can specify the task instruction and also the three objects present on the countertop."
      ],
      "metadata": {
        "id": "hLq1cyarUFs3"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "instruction = \"Put the bottled water in the bin.\" #@param {type:\"string\"}\n",
        "scene_objects = \"energy bar, bottled water, rice chips\" #@param {type:\"string\"}"
      ],
      "metadata": {
        "id": "un9GDoSUilg9",
        "cellView": "form"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## **Access the LLM uncertainty**\n",
        "Next, we would like to see how uncertain the LLM is about the correct action to take in this scenario."
      ],
      "metadata": {
        "id": "2lgsoIZ0voJg"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "#@markdown First, we prompt the LLM to generate possible options with few-shot prompting\n",
        "demo_mc_gen_prompt = \"\"\"\n",
        "We: You are a robot operating in an office kitchen. You are in front of a counter with two closed drawers, a top one and a bottom one. There is also a landfill bin, a recycling bin, and a compost bin.\n",
        "\n",
        "We: On the counter, there is an orange soda, a Pepsi, and an apple.\n",
        "We: Put that drink in the top drawer.\n",
        "You:\n",
        "A) open the top drawer and put the orange soda in it\n",
        "B) open the bottom drawer and put the Pepsi in it\n",
        "C) open the bottom drawer and put the orange soda in it\n",
        "D) open the top drawer and put the Pepsi in it\n",
        "\n",
        "We: On the counter, there is an energy bar, a banana, and a microwave.\n",
        "We: Put the snack next to the microwave.\n",
        "You:\n",
        "A) pick up the energy bar and put it next to the microwave\n",
        "B) pick up the banana and put it next to the energy bar\n",
        "C) pick up the banana and put it next to the microwave\n",
        "D) pick up the energy bar and put it next to the banana\n",
        "\n",
        "We: On the counter, there is a Coke, a Sprite, and a sponge.\n",
        "We: Can you dispose of the can? It should have expired.\n",
        "You:\n",
        "A) pick up the sponge and put it in the landfill bin\n",
        "B) pick up the Coke and put it in the recycling bin\n",
        "C) pick up the Sprite and put it in the recycling bin\n",
        "D) pick up the Coke and put it in the landfill bin\n",
        "\n",
        "We: On the counter, there is a bottled water, a bag of jalapeno chips, and a bag of rice chips.\n",
        "We: I would like a bag of chips.\n",
        "You:\n",
        "A) pick up the bottled water\n",
        "B) pick up the jalapeno chips\n",
        "C) pick up the kettle chips\n",
        "D) pick up the rice chips\n",
        "\n",
        "We: On the counter, there is {scene_objects}\n",
        "We: {task}\n",
        "You:\n",
        "\"\"\"\n",
        "\n",
        "def process_mc_raw(mc_raw, add_mc='an option not listed here'):\n",
        "  mc_all = mc_raw.split('\\n')\n",
        "\n",
        "  mc_processed_all = []\n",
        "  for mc in mc_all:\n",
        "      mc = mc.strip()\n",
        "\n",
        "      # skip nonsense\n",
        "      if len(mc) < 5 or mc[0] not in [\n",
        "          'a', 'b', 'c', 'd', 'A', 'B', 'C', 'D', '1', '2', '3', '4'\n",
        "      ]:\n",
        "          continue\n",
        "      mc = mc[2:]  # remove a), b), ...\n",
        "      mc = mc.strip().lower().split('.')[0]\n",
        "      mc_processed_all.append(mc)\n",
        "  if len(mc_processed_all) < 4:\n",
        "      raise 'Cannot extract four options from the raw output.'\n",
        "\n",
        "  # Check if any repeated option - use do nothing as substitue\n",
        "  mc_processed_all = list(set(mc_processed_all))\n",
        "  if len(mc_processed_all) < 4:\n",
        "      num_need = 4 - len(mc_processed_all)\n",
        "      for _ in range(num_need):\n",
        "          mc_processed_all.append('do nothing')\n",
        "  prefix_all = ['A) ', 'B) ', 'C) ', 'D) ']\n",
        "  if add_mc is not None:\n",
        "      mc_processed_all.append(add_mc)\n",
        "      prefix_all.append('E) ')\n",
        "  random.shuffle(mc_processed_all)\n",
        "\n",
        "  # get full string\n",
        "  mc_prompt = ''\n",
        "  for mc_ind, (prefix, mc) in enumerate(zip(prefix_all, mc_processed_all)):\n",
        "      mc_prompt += prefix + mc\n",
        "      if mc_ind < len(mc_processed_all) - 1:\n",
        "          mc_prompt += '\\n'\n",
        "  add_mc_prefix = prefix_all[mc_processed_all.index(add_mc)][0]\n",
        "  return mc_prompt, mc_processed_all, add_mc_prefix\n",
        "\n",
        "demo_mc_gen_prompt = demo_mc_gen_prompt.replace('{task}', instruction)\n",
        "demo_mc_gen_prompt = demo_mc_gen_prompt.replace('{scene_objects}', scene_objects)\n",
        "\n",
        "# Generate multiple choices\n",
        "_, demo_mc_gen_raw = lm(demo_mc_gen_prompt, stop_seq=['We:'], logit_bias={})\n",
        "demo_mc_gen_raw = demo_mc_gen_raw.strip()\n",
        "demo_mc_gen_full, demo_mc_gen_all, demo_add_mc_prefix = process_mc_raw(demo_mc_gen_raw)\n",
        "\n",
        "print('====== Prompt for generating possible options ======')\n",
        "print(demo_mc_gen_prompt)\n",
        "\n",
        "print('====== Generated options ======')\n",
        "print(demo_mc_gen_full)"
      ],
      "metadata": {
        "cellView": "form",
        "id": "9L0iTMBgsPtm"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@markdown Then we evaluate the probabilities of the LLM predicting each option (A/B/C/D/E)\n",
        "\n",
        "# get the part of the current scenario from the previous prompt\n",
        "demo_cur_scenario_prompt = demo_mc_gen_prompt.split('\\n\\n')[-1].strip()\n",
        "\n",
        "# get new prompt\n",
        "demo_mc_score_background_prompt = \"\"\"\n",
        "You are a robot operating in an office kitchen. You are in front of a counter with two closed drawers, a top one and a bottom one. There is also a landfill bin, a recycling bin, and a compost bin.\n",
        "\"\"\".strip()\n",
        "demo_mc_score_prompt = demo_mc_score_background_prompt + '\\n\\n' + demo_cur_scenario_prompt + '\\n' + demo_mc_gen_full\n",
        "demo_mc_score_prompt += \"\\nWe: Which option is correct? Answer with a single letter.\"\n",
        "demo_mc_score_prompt += \"\\nYou:\"\n",
        "\n",
        "# scoring\n",
        "mc_score_response, _ = lm(demo_mc_score_prompt, max_tokens=1, logprobs=5)\n",
        "top_logprobs_full = mc_score_response[\"choices\"][0][\"logprobs\"][\"top_logprobs\"][0]\n",
        "top_tokens = [token.strip() for token in top_logprobs_full.keys()]\n",
        "top_logprobs = [value for value in top_logprobs_full.values()]\n",
        "\n",
        "print('====== Prompt for scoring options ======')\n",
        "print(demo_mc_score_prompt)\n",
        "\n",
        "print('\\n====== Raw log probabilities for each option ======')\n",
        "for token, logprob in zip(top_tokens, top_logprobs):\n",
        "  print('Option:', token, '\\t', 'log prob:', logprob)"
      ],
      "metadata": {
        "id": "mDlojk5zv0d1",
        "cellView": "form"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## **Construct prediction set**\n",
        "With the probabilities from the LLM, we can construct the prediction set now. From calibration, we have determined the threshold to be 0.072 with a target success level of 0.8. This means the calibration set includes all options with softmax score higher than 0.072. Conformal prediction provides guarantee that the correct action is included in the set with 80% probability!\n",
        "\n",
        "When the set has more than one option, we deem the LLM is uncertain about the correct option and **triggers human help**."
      ],
      "metadata": {
        "id": "4pryBm3sxFTh"
      }
    },
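    {
      "cell_type": "markdown",
      "source": [
        "The calibration that produced the threshold above is not run in this colab. As a rough illustration only, the sketch below assumes the standard split conformal recipe: the nonconformity score of each calibration scenario is 1 minus the softmax score of its correct option, and `qhat` is the finite-sample-corrected quantile of those scores. The function names `calibrate_qhat` and `build_prediction_set` and the calibration scores are hypothetical, not taken from the experiments.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def calibrate_qhat(true_option_scores, epsilon):\n",
        "    # Nonconformity score: 1 - softmax score of the correct option (assumed recipe).\n",
        "    nonconformity = np.sort(1.0 - np.asarray(true_option_scores, dtype=float))\n",
        "    n = len(nonconformity)\n",
        "    # Rank of the finite-sample-corrected (1 - epsilon) quantile.\n",
        "    k = int(np.ceil((n + 1) * (1.0 - epsilon)))\n",
        "    if k > n:\n",
        "        return 1.0  # too few calibration points: keep every option\n",
        "    return float(nonconformity[k - 1])\n",
        "\n",
        "def build_prediction_set(option_scores, qhat):\n",
        "    # Keep the indices of all options whose score is at least 1 - qhat.\n",
        "    return [i for i, p in enumerate(option_scores) if p >= 1.0 - qhat]\n",
        "\n",
        "# Made-up softmax scores of the correct option in nine calibration scenarios.\n",
        "calib_scores = [0.95, 0.9, 0.85, 0.8, 0.7, 0.6, 0.5, 0.3, 0.1]\n",
        "sketch_qhat = calibrate_qhat(calib_scores, epsilon=0.2)\n",
        "print('qhat:', sketch_qhat)\n",
        "print('set:', build_prediction_set([0.55, 0.32, 0.08, 0.03, 0.02], sketch_qhat))\n",
        "```\n",
        "\n",
        "A larger `qhat` lowers the inclusion threshold 1 - `qhat`, so the sets grow and help is requested more often. The sketch uses its own variable name so that it does not clash with the `qhat` defined in the next cell."
      ],
      "metadata": {}
    },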
    {
      "cell_type": "code",
      "source": [
        "#@title\n",
        "qhat = 0.928\n",
        "\n",
        "# get prediction set\n",
        "def temperature_scaling(logits, temperature):\n",
        "    logits = np.array(logits)\n",
        "    logits /= temperature\n",
        "\n",
        "    # apply softmax\n",
        "    logits -= logits.max()\n",
        "    logits = logits - np.log(np.sum(np.exp(logits)))\n",
        "    smx = np.exp(logits)\n",
        "    return smx\n",
        "mc_smx_all = temperature_scaling(top_logprobs, temperature=5)\n",
        "\n",
        "# include all options with score >= 1-qhat\n",
        "prediction_set = [\n",
        "          token for token_ind, token in enumerate(top_tokens)\n",
        "          if mc_smx_all[token_ind] >= 1 - qhat\n",
        "      ]\n",
        "\n",
        "# print\n",
        "print('Softmax scores:', mc_smx_all)\n",
        "print('Prediction set:', prediction_set)\n",
        "if len(prediction_set) != 1:\n",
        "  print('Help needed!')\n",
        "else:\n",
        "  print('No help needed!')"
      ],
      "metadata": {
        "id": "-oaU8ZXUwkh0"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}